package com.yk.dataGatherer.websocket;

import cn.dev33.satoken.exception.SaTokenException;
import cn.dev33.satoken.stp.StpUtil;
import com.yk.dataGatherer.config.MyEndpointConfigure;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Scope;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.io.Serializable;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;

/**
 * WebSocket
 *
 * @author lmx
 */
@Slf4j
@Component
@Scope("prototype")
@ServerEndpoint(value = "/websocket/debugger/{token}/{deviceId}", configurator = MyEndpointConfigure.class)
public class DebuggerWebSocketServer implements Serializable {

    public static final long serialVersionUID = 1L;
    private static final ConcurrentHashMap<String, CopyOnWriteArraySet<DebuggerWebSocketServer>> CLIENT_MAP = new ConcurrentHashMap<>();
    @Getter
    private Session session;
    @Getter
    private String deviceId;

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        DebuggerWebSocketServer that = (DebuggerWebSocketServer) o;
        return Objects.equals(session, that.session) && Objects.equals(deviceId, that.deviceId);
    }

    @Override
    public int hashCode() {
        return Objects.hash(session, deviceId);
    }

    /**
     * 连接建立成功调用的方法
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("token") String token, @PathParam("deviceId") String deviceId) throws Exception {
        Object loginId = StpUtil.getLoginIdByToken(token);
        if (loginId == null) {
            session.close();
            throw new SaTokenException("连接失败，无效Token：" + token);
        }

        this.session = session;
        this.deviceId = deviceId;
        CLIENT_MAP.computeIfAbsent(deviceId, k -> new CopyOnWriteArraySet<>()).add(this);
    }

    /**
     * 连接关闭调用的方法
     */
    @OnClose
    public void onClose() {
        if (this.deviceId != null) {
            CLIENT_MAP.compute(deviceId, (k, v) -> {
                if (v != null) {
                    v.removeIf(server -> server.getSession().equals(session));
                    return v.isEmpty() ? null : v;
                }
                return null;
            });
        }
    }

    /**
     * 发生错误时调用
     */
    @OnError
    public void onError(Throwable error) {
        log.error("webSocket onError：{}", error.getMessage());
    }

    /**
     * 服务器接收到客户端消息时调用的方法
     */
    @OnMessage
    public void onMessage(String message) {
        log.info("webSocket OnMessage：{}", message);
    }

    /**
     * 向指定设备发送消息
     */
    public static void sendMessage(String targetDeviceId, String message) {
        CLIENT_MAP.computeIfPresent(targetDeviceId, (deviceId, servers) -> {
            servers.stream()
                    .filter(DebuggerWebSocketServer::isSessionOpen)
                    .forEach(server -> sendSingleMessage(server, message));
            return servers;
        });
    }

    private static boolean isSessionOpen(DebuggerWebSocketServer server) {
        return server.getSession().isOpen();
    }

    private static void sendSingleMessage(DebuggerWebSocketServer server, String message) {
        try {
            server.sendMessageToClient(message);
        } catch (IOException e) {
            if (!e.getMessage().contains("Channel is closed")) {
                log.error("向设备 [{}] 发送消息时发生错误: {}", server.getDeviceId(), e.getMessage());
            }
        }
    }

    /**
     * 向客户端发送消息
     */
    private synchronized void sendMessageToClient(String message) throws IOException {
        this.getSession().getBasicRemote().sendText(message);
    }
}